# Load the heart-failure dataset (expects heart.csv in the working directory).
df <- read.csv("heart.csv")
# Peek at the first rows to confirm the columns parsed as expected.
head(df)
##   Age Sex ChestPainType RestingBP Cholesterol FastingBS RestingECG MaxHR
## 1  40   M           ATA       140         289         0     Normal   172
## 2  49   F           NAP       160         180         0     Normal   156
## 3  37   M           ATA       130         283         0         ST    98
## 4  48   F           ASY       138         214         0     Normal   108
## 5  54   M           NAP       150         195         0     Normal   122
## 6  39   M           NAP       120         339         0     Normal   170
##   ExerciseAngina Oldpeak ST_Slope HeartDisease
## 1              N     0.0       Up            0
## 2              N     1.0     Flat            1
## 3              N     0.0       Up            0
## 4              Y     1.5     Flat            1
## 5              N     0.0       Up            0
## 6              N     0.0       Up            0
# Five-number summaries. Note RestingBP and Cholesterol both show a minimum
# of 0, which is physiologically implausible -- treated as missing later.
summary(df)
##       Age            Sex            ChestPainType        RestingBP    
##  Min.   :28.00   Length:918         Length:918         Min.   :  0.0  
##  1st Qu.:47.00   Class :character   Class :character   1st Qu.:120.0  
##  Median :54.00   Mode  :character   Mode  :character   Median :130.0  
##  Mean   :53.51                                         Mean   :132.4  
##  3rd Qu.:60.00                                         3rd Qu.:140.0  
##  Max.   :77.00                                         Max.   :200.0  
##   Cholesterol      FastingBS       RestingECG            MaxHR      
##  Min.   :  0.0   Min.   :0.0000   Length:918         Min.   : 60.0  
##  1st Qu.:173.2   1st Qu.:0.0000   Class :character   1st Qu.:120.0  
##  Median :223.0   Median :0.0000   Mode  :character   Median :138.0  
##  Mean   :198.8   Mean   :0.2331                      Mean   :136.8  
##  3rd Qu.:267.0   3rd Qu.:0.0000                      3rd Qu.:156.0  
##  Max.   :603.0   Max.   :1.0000                      Max.   :202.0  
##  ExerciseAngina        Oldpeak          ST_Slope          HeartDisease   
##  Length:918         Min.   :-2.6000   Length:918         Min.   :0.0000  
##  Class :character   1st Qu.: 0.0000   Class :character   1st Qu.:0.0000  
##  Mode  :character   Median : 0.6000   Mode  :character   Median :1.0000  
##                     Mean   : 0.8874                      Mean   :0.5534  
##                     3rd Qu.: 1.5000                      3rd Qu.:1.0000  
##                     Max.   : 6.2000                      Max.   :1.0000
# No literal NA values in the raw file; "missingness" is encoded as 0s.
colSums(is.na(df))
##            Age            Sex  ChestPainType      RestingBP    Cholesterol 
##              0              0              0              0              0 
##      FastingBS     RestingECG          MaxHR ExerciseAngina        Oldpeak 
##              0              0              0              0              0 
##       ST_Slope   HeartDisease 
##              0              0
library(ggplot2)
## Warning: package 'ggplot2' was built under R version 4.5.2
library(gridExtra)

# Continuous predictors used throughout the EDA and outlier analysis.
numerical_vars <- c("Age", "RestingBP", "Cholesterol", "MaxHR", "Oldpeak")

plots <- list()  # NOTE(review): never read afterwards; kept for compatibility

histograms <- list()

# One density-scaled histogram per numeric variable with a kernel-density
# overlay. `after_stat(density)` replaces the `..density..` notation that
# was deprecated in ggplot2 3.4.0 (source of the earlier knit warning).
for (var in numerical_vars) {
  p <- ggplot(df, aes(x = .data[[var]])) +
    geom_histogram(aes(y = after_stat(density)), bins = 30, fill = "#69b3a2",
                   color = "white", alpha = 0.8) +
    geom_density(color = "#D55E00", linewidth = 1.2, alpha = 0.7) +
    theme_classic(base_size = 14) +
    labs(
      title = paste("Histogram of", var),
      x = var,
      y = "Density"
    ) +
    theme(
      plot.title = element_text(hjust = 0.5, face = "bold", size = 14),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 10),
      panel.grid.major = element_line(color = "gray90"),
      panel.grid.minor = element_blank()
    ) +
    scale_x_continuous(expand = expansion(mult = c(0.05, 0.05))) +
    scale_y_continuous(expand = expansion(mult = c(0, 0.1)))

  histograms[[var]] <- p
}

# Arrange the five histograms on a single canvas, three per row.
grid.arrange(grobs = histograms, ncol = 3)

overlay_plots <- list()
outliers <- list()

# For each numeric variable: flag Tukey-fence outliers (1.5 * IQR beyond
# the quartiles) and draw a violin plot with an inset boxplot so the
# distribution shape and the flagged points can be inspected together.
for (var in numerical_vars) {
  Q1 <- quantile(df[[var]], 0.25, na.rm = TRUE)
  Q3 <- quantile(df[[var]], 0.75, na.rm = TRUE)
  iqr <- Q3 - Q1  # lower-case local name avoids shadowing base::IQR()

  lower_bound <- Q1 - 1.5 * iqr
  upper_bound <- Q3 + 1.5 * iqr

  outlier_values <- df[[var]][df[[var]] < lower_bound | df[[var]] > upper_bound]

  outliers[[var]] <- list(
    "Lower Bound" = lower_bound,
    "Upper Bound" = upper_bound,
    "Outliers" = outlier_values
  )

  # x = "" collapses every observation into a single violin/box column.
  p <- ggplot(df, aes(x = "", y = .data[[var]])) +
    geom_violin(
      fill = "#0099cc",
      color = "black",
      alpha = 0.5
    ) +
    geom_boxplot(
      width = 0.15,
      fill = "white",
      outlier.shape = 16,
      outlier.color = "red",
      outlier.size = 2
    ) +
    theme_minimal(base_size = 14) +
    labs(
      title = paste("Violin + Boxplot of", var),
      x = "",
      y = var
    ) +
    theme(
      plot.title = element_text(
        hjust = 0.5,
        face = "bold",
        size = 12
      ),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 8),
      panel.grid.major = element_line(color = "gray90"),
      panel.grid.minor = element_blank()
    ) +
    scale_y_continuous(
      expand = expansion(mult = c(0.05, 0.05))
    )

  overlay_plots[[var]] <- p
}

grid.arrange(grobs = overlay_plots, ncol = 3)

library(VIM)
## Loading required package: colorspace
## Loading required package: grid
## VIM is ready to use.
## Suggestions and bug-reports can be submitted at: https://github.com/statistikat/VIM/issues
## 
## Attaching package: 'VIM'
## The following object is masked from 'package:datasets':
## 
##     sleep
# Inspect the Tukey-fence results. RestingBP flags a 0 and Cholesterol
# flags many 0s, confirming that 0 encodes a missing measurement rather
# than a real clinical reading.
outliers
## $Age
## $Age$`Lower Bound`
##  25% 
## 27.5 
## 
## $Age$`Upper Bound`
##  75% 
## 79.5 
## 
## $Age$Outliers
## integer(0)
## 
## 
## $RestingBP
## $RestingBP$`Lower Bound`
## 25% 
##  90 
## 
## $RestingBP$`Upper Bound`
## 75% 
## 170 
## 
## $RestingBP$Outliers
##  [1] 190 180 180 180 200 180 180 180  80 200 185 200 180 180   0 178 172 180 190
## [20] 174 178 180 200 192 178 180 180 172
## 
## 
## $Cholesterol
## $Cholesterol$`Lower Bound`
##    25% 
## 32.625 
## 
## $Cholesterol$`Upper Bound`
##     75% 
## 407.625 
## 
## $Cholesterol$Outliers
##   [1] 468 518 412 529 466 603 491   0   0   0   0   0   0   0   0   0   0   0
##  [19]   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
##  [37]   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
##  [55]   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
##  [73]   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
##  [91]   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
## [109]   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
## [127]   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
## [145]   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
## [163]   0   0   0   0   0   0   0   0   0   0   0 458   0   0   0   0   0   0
## [181] 564 417 409
## 
## 
## $MaxHR
## $MaxHR$`Lower Bound`
## 25% 
##  66 
## 
## $MaxHR$`Upper Bound`
## 75% 
## 210 
## 
## $MaxHR$Outliers
## [1] 63 60
## 
## 
## $Oldpeak
## $Oldpeak$`Lower Bound`
##   25% 
## -2.25 
## 
## $Oldpeak$`Upper Bound`
##  75% 
## 3.75 
## 
## $Oldpeak$Outliers
##  [1]  4.0  5.0 -2.6  4.0  4.0  4.0  4.0  4.0  4.2  4.0  5.6  3.8  4.2  6.2  4.4
## [16]  4.0
# A resting blood pressure or serum cholesterol of 0 is not physiologically
# plausible -- recode the zero placeholders as missing.
df$RestingBP[df$RestingBP == 0] <- NA
df$Cholesterol[df$Cholesterol == 0] <- NA

# Impute the recoded NAs with k-nearest-neighbour imputation (VIM::kNN);
# imp_var = FALSE suppresses the extra indicator columns.
# NOTE(review): imputation runs on the full dataset before the train/test
# split created later, so test rows can inform imputed training values --
# confirm this mild leakage is acceptable for the analysis.
df <- kNN(df, k = 5, imp_var = FALSE)
##          Age  Cholesterol    FastingBS        MaxHR      Oldpeak HeartDisease 
##         28.0         85.0          0.0         60.0         -2.6          0.0 
##          Age  Cholesterol    FastingBS        MaxHR      Oldpeak HeartDisease 
##         77.0        603.0          1.0        202.0          6.2          1.0 
##          Age    RestingBP    FastingBS        MaxHR      Oldpeak HeartDisease 
##         28.0         80.0          0.0         60.0         -2.6          0.0 
##          Age    RestingBP    FastingBS        MaxHR      Oldpeak HeartDisease 
##         77.0        200.0          1.0        202.0          6.2          1.0
# Re-draw the per-variable histograms after imputation to verify that the
# zero spikes in RestingBP / Cholesterol are gone.
plots <- list()  # NOTE(review): never read afterwards; kept for compatibility
histograms <- list()

# `after_stat(density)` replaces the `..density..` notation deprecated in
# ggplot2 3.4.0.
for (var in numerical_vars) {
  p <- ggplot(df, aes(x = .data[[var]])) +
    geom_histogram(aes(y = after_stat(density)), bins = 30, fill = "#69b3a2",
                   color = "white", alpha = 0.8) +
    geom_density(color = "#D55E00", linewidth = 1.2, alpha = 0.7) +
    theme_classic(base_size = 14) +
    labs(
      title = paste("Histogram of", var),
      x = var,
      y = "Density"
    ) +
    theme(
      plot.title = element_text(hjust = 0.5, face = "bold", size = 14),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 10),
      panel.grid.major = element_line(color = "gray90"),
      panel.grid.minor = element_blank()
    ) +
    scale_x_continuous(expand = expansion(mult = c(0.05, 0.05))) +
    scale_y_continuous(expand = expansion(mult = c(0, 0.1)))

  histograms[[var]] <- p
}

grid.arrange(grobs = histograms, ncol = 3)

library(GGally)

# Numeric predictors plus the target. The target stays in the data frame
# only for the colour mapping; it is excluded from the panel grid via
# `columns` because the 0/1 HeartDisease column is constant within each
# colour group, which made per-group cor() emit repeated
# "standard deviation is zero" warnings and degenerate panels.
numerical_vars2 <- df[c("HeartDisease", "Age", "RestingBP", "Cholesterol", "MaxHR", "Oldpeak")]

scatter_plot_matrix <- ggpairs(
  numerical_vars2,
  aes(color = as.factor(HeartDisease), alpha = 0.7),
  columns = c("Age", "RestingBP", "Cholesterol", "MaxHR", "Oldpeak"),
  lower = list(
    continuous = wrap("points", alpha = 0.7)
  ),
  diag = list(
    continuous = wrap("densityDiag", alpha = 0.6)
  ),
  upper = list(
    continuous = wrap("cor", size = 4, alignPercent = 0.5)
  ),
  title = "Scatter Plot Matrix of Numerical Variables Grouped by HeartDisease"
)

print(scatter_plot_matrix)

library(dplyr)
## 
## Attaching package: 'dplyr'
## The following object is masked from 'package:gridExtra':
## 
##     combine
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
# Categorical predictors to visualise against the target.
categorical_vars <- c("Sex", "ChestPainType", "FastingBS", "RestingECG", "ExerciseAngina", "ST_Slope")

# Build one dodged bar chart for a single categorical variable. Labels show
# each (category, outcome) cell as a percentage of all observations.
build_category_plot <- function(var) {
  counts_df <- df %>%
    group_by(.data[[var]], HeartDisease) %>%
    summarise(Count = n(), .groups = "drop") %>%
    mutate(Percent = Count / sum(Count) * 100)

  ggplot(counts_df, aes(x = .data[[var]], y = Count, fill = factor(HeartDisease))) +
    geom_bar(stat = "identity", position = "dodge", alpha = 0.8) +
    geom_text(
      aes(label = paste0(round(Percent, 1), "%")),
      position = position_dodge(width = 0.9), vjust = 0.5, size = 2
    ) +
    labs(
      title = paste("Heart Disease by", var),
      x = var,
      y = "Count",
      fill = "Heart Disease"
    ) +
    theme_minimal(base_size = 12) +
    theme(
      plot.title = element_text(hjust = 0.5, face = "bold", size = 10),
      axis.title = element_text(size = 8),
      axis.text = element_text(size = 8),
      legend.position = "right"
    ) +
    scale_fill_manual(values = c("#96ceb4", "#ffef96"))
}

# One plot per variable, keyed by variable name (same structure as before).
category_plots <- setNames(lapply(categorical_vars, build_category_plot), categorical_vars)

# Lay the six charts out on a 3-column grid with equal row heights.
n_cols <- 3
n_rows <- ceiling(length(categorical_vars) / n_cols)

grid.arrange(
  grobs = category_plots,
  ncol = n_cols,
  nrow = n_rows,
  heights = rep(1, n_rows),
  top = textGrob("Categorical Variables Distribution", gp = gpar(fontsize = 15, fontface = "bold"))
)

library(corrplot)
## corrplot 0.95 loaded
# Association between each categorical predictor and the target, via a
# Pearson chi-squared test on the 2-way contingency table.
target <- "HeartDisease"

categorical_corr_results <- lapply(categorical_vars, function(v) {
  contingency <- table(df[[v]], df[[target]])
  test_out <- chisq.test(contingency)
  list(
    Variable = v,
    Chi_Square = test_out$statistic,
    P_Value = test_out$p.value
  )
})

# Print each result list (same console output as the explicit for loop).
invisible(lapply(categorical_corr_results, print))
## $Variable
## [1] "Sex"
## 
## $Chi_Square
## X-squared 
##   84.1451 
## 
## $P_Value
## [1] 4.597617e-20
## 
## $Variable
## [1] "ChestPainType"
## 
## $Chi_Square
## X-squared 
##  268.0672 
## 
## $P_Value
## [1] 8.083728e-58
## 
## $Variable
## [1] "FastingBS"
## 
## $Chi_Square
## X-squared 
##  64.32068 
## 
## $P_Value
## [1] 1.057302e-15
## 
## $Variable
## [1] "RestingECG"
## 
## $Chi_Square
## X-squared 
##  10.93147 
## 
## $P_Value
## [1] 0.004229233
## 
## $Variable
## [1] "ExerciseAngina"
## 
## $Chi_Square
## X-squared 
##  222.2594 
## 
## $P_Value
## [1] 2.907808e-50
## 
## $Variable
## [1] "ST_Slope"
## 
## $Chi_Square
## X-squared 
##  355.9184 
## 
## $P_Value
## [1] 5.167638e-78
library(fastDummies)
library(pheatmap)
library(dplyr)

# One-hot encode the character/factor columns; remove_first_dummy = TRUE
# drops one level per variable as the reference class (listed further down),
# and remove_selected_columns = TRUE drops the original columns.
df_dummied <- dummy_cols(df, remove_selected_columns = TRUE, remove_first_dummy = TRUE)

# Set notebook plot dimensions via the `repr` options (used by
# IRkernel/Jupyter-style front ends to size figures).
#
# Args:
#   x: plot width (single numeric value).
#   y: plot height (single numeric value).
#
# Invisibly returns the previous option values (as options() does), so a
# caller can restore them later.
fig <- function(x, y) {
  stopifnot(is.numeric(x), length(x) == 1, is.numeric(y), length(y) == 1)
  options(repr.plot.width = x, repr.plot.height = y)
}

# Enlarge the plotting canvas for the heatmap.
fig(15, 15)

# Spearman rank correlations between the numeric columns of the original
# frame (Age, RestingBP, Cholesterol, FastingBS, MaxHR, Oldpeak,
# HeartDisease).
# NOTE(review): df_dummied was just created above but is not used here --
# confirm whether the heatmap was meant to include the dummy columns.
df_cor <- cor(df |> select(where(is.numeric)), use = "pairwise.complete.obs", method = 'spearman')
# seq(-1, 1, by = 0.02) gives 101 breaks -> 100 bins, matching the palette.
pheatmap(df_cor, 
        display_numbers = TRUE, 
        main = "Correlation Heatmap", 
        fontsize_col = 20, fontsize_row = 20, fontsize = 20, fontsize_number = 15,
        breaks = seq(-1, 1, by = 0.02),
        color = colorRampPalette(c("steelblue", "white", "firebrick"))(100)
        )

# Re-declared for chunk independence; identical to the earlier definition.
categorical_vars <- c("Sex", "ChestPainType", "FastingBS", "RestingECG", "ExerciseAngina", "ST_Slope")
cat("Reference Levels (The categories dropped by remove_first_dummy = TRUE):\n")
## Reference Levels (The categories dropped by remove_first_dummy = TRUE):
cat("---------------------------------------------------------------------\n")
## ---------------------------------------------------------------------
# Report the reference (dropped) level of each dummy-encoded variable:
# the first factor level, i.e. the alphabetically-first category.
for (v in categorical_vars) {
  baseline <- levels(factor(df[[v]]))[1]
  cat(sprintf("Variable: %-15s | Reference Class: %s\n", v, baseline))
}
## Variable: Sex             | Reference Class: F
## Variable: ChestPainType   | Reference Class: ASY
## Variable: FastingBS       | Reference Class: 0
## Variable: RestingECG      | Reference Class: LVH
## Variable: ExerciseAngina  | Reference Class: N
## Variable: ST_Slope        | Reference Class: Down
library(caret)
## Loading required package: lattice
# Fix the RNG so the partition below is reproducible.
set.seed(42)
# 80/20 split, stratified on the outcome (y passed as a factor).
train_index <- createDataPartition(y=as.factor(df_dummied$HeartDisease), p=0.8, list=FALSE)

train_set <- df_dummied[train_index, ]
test_set <- df_dummied[-train_index, ]
# Evaluate a binary classifier on held-out data.
#
# Prints Accuracy / ROC AUC / Brier score, draws a 2x2 panel (ROC curve,
# precision-recall curve, calibration plot, confusion-matrix tile plot)
# and returns the plots plus caret's byClass metrics. The positive class
# is "1" throughout.
#
# Args:
#   model:       a caret `train` object, a (cv.)glmnet fit, or any model
#                whose predict(..., type = "response") yields probabilities.
#   test_data:   data.frame of test predictors (ignored for glmnet models,
#                which use `glmnet_newx` instead).
#   outcome:     either the name of the outcome column in `test_data`, or
#                a vector of true labels (factor or 0/1 numeric).
#   threshold:   probability cutoff for the hard classification; note the
#                strict `>`, so a probability exactly at the threshold is
#                classed 0.
#   glmnet_newx: model matrix passed to predict() for glmnet models.
#   glmnet_s:    glmnet lambda choice (default "lambda.min").
#
# Returns: list(plots = <4 ggplots>, metrics = cm$byClass,
#               extras = list(Accuracy, Brier, AUC_ROC, AUC_PR)).
evaluate_model <- function(model, test_data, outcome, threshold = 0.5, 
                           glmnet_newx = NULL, glmnet_s = "lambda.min") {
  
  # --- 1. Generate Probabilities ---
  if (inherits(model, "train")) {
    pred_probs <- predict(model, test_data, type = "prob")
    # Locate the positive-class probability column by common level names;
    # otherwise assume it is the second column.
    pos_col <- intersect(names(pred_probs), c("X1", "Yes", "1", "Positive"))
    if (length(pos_col) > 0) {
      probs <- pred_probs[[pos_col[1]]]
    } else {
      probs <- pred_probs[, 2]
    }
  } else if (inherits(model, "cv.glmnet") || inherits(model, "glmnet")) {
    # Ensure explicit numeric vector conversion for Lasso matrices
    probs <- as.vector(predict(model, newx = glmnet_newx, s = glmnet_s, type = "response"))
  } else {
    probs <- predict(model, test_data, type = "response")
  }

  # --- 2. Prepare Truth Vector (0/1) ---
  if (is.character(outcome) && length(outcome) == 1 && is.data.frame(test_data)) {
    truth_vec <- test_data[[outcome]]
  } else {
    truth_vec <- outcome
  }

  # Robust conversion to 0/1 numeric
  if (is.factor(truth_vec)) {
    # If levels are "X0", "X1" or "0", "1", this maps the first level to 0 and second to 1
    # (assumes the factor's second level is the positive class).
    truth_num <- as.numeric(truth_vec) - 1 
  } else {
    truth_num <- as.numeric(truth_vec)
  }

  # --- 3. Calculate Metrics ---
  pred_class <- ifelse(probs > threshold, 1, 0)
  pred_factor <- factor(pred_class, levels = c(0, 1))
  truth_factor <- factor(truth_num, levels = c(0, 1))

  cm <- caret::confusionMatrix(pred_factor, truth_factor, mode = "everything", positive = "1")

  # NOTE(review): precision/recall/f1_score are extracted here but only
  # returned via cm$byClass, not individually.
  accuracy    <- cm$overall["Accuracy"]
  precision   <- cm$byClass["Precision"]
  recall      <- cm$byClass["Recall"]
  f1_score    <- cm$byClass["F1"]
  # Brier score: mean squared gap between probability and 0/1 truth.
  brier_score <- mean((probs - truth_num)^2)

  # --- 4. ROC & PR Curves ---
  roc_obj <- pROC::roc(truth_num, probs, quiet = TRUE)
  auc_val <- pROC::auc(roc_obj)
  roc_df <- data.frame(TPR = roc_obj$sensitivities, FPR = 1 - roc_obj$specificities)

  p_roc <- ggplot2::ggplot(roc_df, ggplot2::aes(x = FPR, y = TPR)) +
    ggplot2::geom_line(color = "blue", linewidth = 1) +
    ggplot2::geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "gray") +
    ggplot2::labs(title = paste0("ROC (AUC = ", round(auc_val, 3), ")"), x = "1 - Specificity", y = "Sensitivity") +
    ggplot2::theme_minimal()

  pred_obj <- ROCR::prediction(probs, truth_num)
  perf_pr <- ROCR::performance(pred_obj, "prec", "rec")
  pr_df <- data.frame(Recall = perf_pr@x.values[[1]], Precision = perf_pr@y.values[[1]])
  # Drop rows where precision is undefined (NA/NaN) before plotting.
  pr_df <- na.omit(pr_df)
  auc_pr <- ROCR::performance(pred_obj, "aucpr")@y.values[[1]]

  p_pr <- ggplot2::ggplot(pr_df, ggplot2::aes(x = Recall, y = Precision)) +
    ggplot2::geom_line(color = "darkgreen", linewidth = 1) +
    ggplot2::labs(title = paste0("PR Curve (AUC = ", round(auc_pr, 3), ")"), x = "Recall", y = "Precision") +
    ggplot2::theme_minimal()

  # --- 5. Calibration Plot ---
  # Decile bins of predicted probability vs the observed event fraction.
  cal_df <- data.frame(truth = truth_num, prob = probs)
  cal_df$bin <- cut(cal_df$prob, breaks = seq(0, 1, 0.1), include.lowest = TRUE)
  cal_summary <- dplyr::group_by(cal_df, bin) |> 
    dplyr::summarise(mean_prob = mean(prob), obs_frac = mean(truth), n = dplyr::n()) |> 
    na.omit()

  p_cal <- ggplot2::ggplot(cal_summary, ggplot2::aes(x = mean_prob, y = obs_frac)) +
    ggplot2::geom_point(ggplot2::aes(size = n), alpha = 0.7) + ggplot2::geom_line() +
    ggplot2::geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "red") +
    ggplot2::labs(title = "Calibration", x = "Predicted Prob", y = "Observed Fraction") +
    ggplot2::xlim(0, 1) + ggplot2::ylim(0, 1) + ggplot2::theme_minimal()

  # --- 6. Confusion Matrix Plot ---
  cm_df <- as.data.frame(cm$table)
  # Reverse the prediction axis so class "1" sits on top of the tile plot.
  cm_df$Prediction <- factor(cm_df$Prediction, levels = c(1, 0))
  p_cm <- ggplot2::ggplot(cm_df, ggplot2::aes(x = Reference, y = Prediction, fill = Freq)) +
    ggplot2::geom_tile() + ggplot2::geom_text(ggplot2::aes(label = Freq), color = "white", size = 6) +
    ggplot2::labs(title = "Confusion Matrix") + ggplot2::theme_minimal() + ggplot2::theme(legend.position = "none")

  gridExtra::grid.arrange(p_roc, p_pr, p_cal, p_cm, ncol = 2)

  # --- 7. Print Metrics to Console ---
  cat("Accuracy:    ", round(accuracy, 4), "\n")
  cat("ROC AUC:     ", round(auc_val, 4), "\n")
  cat("Brier Score: ", round(brier_score, 4), "\n") 

  list(plots = list(p_roc, p_pr, p_cal, p_cm),
       metrics = cm$byClass,
       extras = list(Accuracy = accuracy, Brier = brier_score, AUC_ROC = as.numeric(auc_val), AUC_PR = as.numeric(auc_pr)))
}
library(naivebayes)
## naivebayes 1.0.0 loaded
## For more information please visit:
## https://majkamichal.github.io/naivebayes/
# Naive Bayes runs on the original (non-dummied) frame, with the
# categorical columns converted to factors.
df_nb <- df

factor_cols <- c("Sex", "ChestPainType", "FastingBS", "RestingECG",
                 "ExerciseAngina", "ST_Slope", "HeartDisease")
for (col in factor_cols) {
  df_nb[[col]] <- as.factor(df_nb[[col]])
}

# caret needs syntactically valid class labels ("X0" / "X1").
levels(df_nb$HeartDisease) <- make.names(levels(df_nb$HeartDisease))

# Reuse the stratified split indices created earlier.
train_set_nb <- df_nb[train_index, ]
test_set_nb  <- df_nb[-train_index, ]

# 5-fold CV with class probabilities and precision/recall summaries;
# this control object is shared with the later model fits.
cv_control <- trainControl(
  method = "cv",
  number = 5,
  classProbs = TRUE,
  summaryFunction = prSummary,
  verboseIter = FALSE,
  allowParallel = TRUE
)

cv_control_nb <- cv_control

# Grid over kernel-density use, Laplace smoothing and bandwidth adjustment.
nb_grid <- expand.grid(
  usekernel = c(TRUE, FALSE),
  laplace = c(0, 0.5, 1),
  adjust = c(0.75, 1, 1.25)
)

# Seed immediately before train() so the CV fold draw is reproducible.
set.seed(42)
nb_tuned <- train(
  HeartDisease ~ .,
  data = train_set_nb,
  method = "naive_bayes",
  trControl = cv_control_nb,
  tuneGrid = nb_grid,
  metric = "F"
)

print(nb_tuned)
## Naive Bayes 
## 
## 735 samples
##  11 predictor
##   2 classes: 'X0', 'X1' 
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 587, 589, 588, 589, 587 
## Resampling results across tuning parameters:
## 
##   usekernel  laplace  adjust  AUC        Precision  Recall     F        
##   FALSE      0.0      0.75    0.8840004  0.8423824  0.8442890  0.8430958
##   FALSE      0.0      1.00    0.8840004  0.8423824  0.8442890  0.8430958
##   FALSE      0.0      1.25    0.8840004  0.8423824  0.8442890  0.8430958
##   FALSE      0.5      0.75    0.8840004  0.8423824  0.8442890  0.8430958
##   FALSE      0.5      1.00    0.8840004  0.8423824  0.8442890  0.8430958
##   FALSE      0.5      1.25    0.8840004  0.8423824  0.8442890  0.8430958
##   FALSE      1.0      0.75    0.8840004  0.8423824  0.8442890  0.8430958
##   FALSE      1.0      1.00    0.8840004  0.8423824  0.8442890  0.8430958
##   FALSE      1.0      1.25    0.8840004  0.8423824  0.8442890  0.8430958
##    TRUE      0.0      0.75    0.8783013  0.8764982  0.7591142  0.8131643
##    TRUE      0.0      1.00    0.8773238  0.8711361  0.7621911  0.8128045
##    TRUE      0.0      1.25    0.8797852  0.8711361  0.7621911  0.8128045
##    TRUE      0.5      0.75    0.8783013  0.8764982  0.7591142  0.8131643
##    TRUE      0.5      1.00    0.8773238  0.8711361  0.7621911  0.8128045
##    TRUE      0.5      1.25    0.8797852  0.8711361  0.7621911  0.8128045
##    TRUE      1.0      0.75    0.8783013  0.8764982  0.7591142  0.8131643
##    TRUE      1.0      1.00    0.8773238  0.8711361  0.7621911  0.8128045
##    TRUE      1.0      1.25    0.8797852  0.8711361  0.7621911  0.8128045
## 
## F was used to select the optimal model using the largest value.
## The final values used for the model were laplace = 0, usekernel = FALSE
##  and adjust = 0.75.
# Held-out evaluation of the tuned Naive Bayes model (default 0.5 cutoff).
evaluate_model(
  model = nb_tuned, 
  test_data = test_set_nb, 
  outcome = "HeartDisease"
)

## Accuracy:     0.8689 
## ROC AUC:      0.926 
## Brier Score:  0.1203
## $plots
## $plots[[1]]

## 
## $plots[[2]]

## 
## $plots[[3]]

## 
## $plots[[4]]

## 
## 
## $metrics
##          Sensitivity          Specificity       Pos Pred Value 
##            0.8712871            0.8658537            0.8888889 
##       Neg Pred Value            Precision               Recall 
##            0.8452381            0.8888889            0.8712871 
##                   F1           Prevalence       Detection Rate 
##            0.8800000            0.5519126            0.4808743 
## Detection Prevalence    Balanced Accuracy 
##            0.5409836            0.8685704 
## 
## $extras
## $extras$Accuracy
##  Accuracy 
## 0.8688525 
## 
## $extras$Brier
## [1] 0.1202935
## 
## $extras$AUC_ROC
## [1] 0.9259841
## 
## $extras$AUC_PR
## [1] 0.9355863
# caret's variable importance for the tuned NB model (lattice dot plot).
varImp_nb <- varImp(nb_tuned)
plot(varImp_nb, main = "Naive Bayes Variable Importance")

library(MLmetrics)
## 
## Attaching package: 'MLmetrics'
## The following objects are masked from 'package:caret':
## 
##     MAE, RMSE
## The following object is masked from 'package:base':
## 
##     Recall
# Relabel a 0/1 outcome as a factor with caret-safe level names (X0/X1).
as_class_factor <- function(v) {
  f <- factor(v)
  levels(f) <- make.names(levels(f))
  f
}

y_train <- as_class_factor(train_set$HeartDisease)
y_test <- as_class_factor(test_set$HeartDisease)

# Intercept-free design matrices for xgboost.
X_train <- model.matrix(HeartDisease ~ . -1, data = train_set)
X_test <- model.matrix(HeartDisease ~ . -1, data = test_set)

# Tune tree count, depth, learning rate and minimum split loss; the
# sampling parameters are held fixed.
xgb_grid <- expand.grid(
  nrounds = c(150, 100, 50),
  max_depth = c(2, 3, 5),
  eta = c(0.05, 0.1, 0.3),
  gamma = c(0, 0.1, 0.5, 1),
  colsample_bytree = 0.8,
  min_child_weight = 1,
  subsample = 0.8
)

# Seed immediately before train() so the CV fold draw is reproducible.
set.seed(42)

xgb_tuned <- train(
  x = X_train,
  y = y_train,
  method = "xgbTree",
  trControl = cv_control,
  tuneGrid = xgb_grid,
  metric = "F",
  verbosity = 0
)
library(pROC)
## Type 'citation("pROC")' for a citation.
## 
## Attaching package: 'pROC'
## The following object is masked from 'package:colorspace':
## 
##     coords
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
library(ROCR)
library(gridExtra)

# Held-out evaluation of the tuned XGBoost model (default 0.5 cutoff).
evaluate_model(xgb_tuned, X_test, test_set$HeartDisease)

## Accuracy:     0.8743 
## ROC AUC:      0.9393 
## Brier Score:  0.0918
## $plots
## $plots[[1]]

## 
## $plots[[2]]

## 
## $plots[[3]]

## 
## $plots[[4]]

## 
## 
## $metrics
##          Sensitivity          Specificity       Pos Pred Value 
##            0.8910891            0.8536585            0.8823529 
##       Neg Pred Value            Precision               Recall 
##            0.8641975            0.8823529            0.8910891 
##                   F1           Prevalence       Detection Rate 
##            0.8866995            0.5519126            0.4918033 
## Detection Prevalence    Balanced Accuracy 
##            0.5573770            0.8723738 
## 
## $extras
## $extras$Accuracy
##  Accuracy 
## 0.8743169 
## 
## $extras$Brier
## [1] 0.09182656
## 
## $extras$AUC_ROC
## [1] 0.9392659
## 
## $extras$AUC_PR
## [1] 0.9374764
library(SHAPforxgboost)

# Long-format SHAP values for the 15 most important features, computed from
# the final xgboost booster on the training design matrix.
shap_long <- shap.prep(
  xgb_model = xgb_tuned$finalModel, 
  X_train = X_train,
  top_n = 15
)

# Beeswarm summary of the SHAP values. Overriding the colour scale triggers
# the "Scale for colour is already present" message below -- expected, as
# the replacement of the plot's existing gradient is intentional.
shap.plot.summary(shap_long) + 
  scale_color_gradient(low = "#0072B2", high = "#D55E00", 
                       name = "Feature Value") +
  
  labs(
    title = "XGBoost: Feature Impact (SHAP Values)",
    subtitle = "Top 15 Most Important Predictors",
    x = "SHAP Value (Impact on Log-Odds)"
  ) +
  theme_bw() +
  theme(
    legend.position = "right",
    axis.text.y = element_text(size = 10, face = "bold"),
    plot.title = element_text(face = "bold")
  )
## Scale for colour is already present.
## Adding another scale for colour, which will replace the existing scale.

# Evaluate a fitted caret model on age-stratified subsets of `data`.
#
# Args:
#   data:       data.frame containing the model-matrix columns plus
#               `age_col` and `target_col` (target as a factor X0/X1).
#   model:      fitted caret model; predict(..., type = "prob") must return
#               probability columns named after the class levels.
#   age_col:    name of the numeric age column.
#   target_col: name of the outcome column.
#   age_groups: subset of c("<45", "45-60", ">60"). Unknown labels are now
#               skipped with a warning (previously an unknown label silently
#               reused the subset from the prior loop iteration).
#
# Returns a data.frame with Accuracy / Recall / Precision / F1 per group.
# Metrics use positive = "X1" so Recall/Precision refer to the disease
# class, consistent with evaluate_model() (positive = "1"); the original
# defaulted to the first level "X0".
results_by_age <- function(data, model, age_col, target_col, age_groups) {
  rows <- list()

  for (group in age_groups) {
    subset_data <- switch(
      group,
      "<45"   = data[data[[age_col]] < 45, ],
      "45-60" = data[data[[age_col]] >= 45 & data[[age_col]] <= 60, ],
      ">60"   = data[data[[age_col]] > 60, ],
      NULL
    )
    if (is.null(subset_data)) {
      warning(paste("Unknown age group label:", group))
      next
    }

    if (nrow(subset_data) == 0) {
      warning(paste("Group", group, "has no samples!"))
      next
    }

    print(paste("Group:", group))
    # Use target_col rather than a hard-coded column name.
    print(table(subset_data[[target_col]]))

    y_subset <- factor(subset_data[[target_col]], levels = c("X0", "X1"))
    print(table(y_subset))

    if (length(unique(y_subset)) < 2) {
      warning(paste("Group", group, "has insufficient categories!"))
      next
    }

    X_subset <- model.matrix(as.formula(paste(target_col, "~ . -1")), data = subset_data)

    pred_probs <- predict(model, newdata = X_subset, type = "prob")
    # Select the positive-class column by name rather than by position.
    pred_class <- ifelse(pred_probs[, "X1"] > 0.5, "X1", "X0")
    pred_class <- factor(pred_class, levels = levels(y_subset))

    print(table(pred_class))

    confusion <- caret::confusionMatrix(pred_class, y_subset, positive = "X1")
    # Collect one-row frames and bind once (avoids rbind-in-loop growth).
    rows[[length(rows) + 1L]] <- data.frame(
      AgeGroup = group,
      Accuracy = confusion$overall["Accuracy"],
      Recall = confusion$byClass["Recall"],
      Precision = confusion$byClass["Precision"],
      F1 = confusion$byClass["F1"]
    )
  }

  if (length(rows) == 0) {
    return(data.frame())
  }
  do.call(rbind, rows)
}
age_groups <- c("<45", "45-60", ">60")

# Relabel the numeric 0/1 outcome to match the model's class labels.
test_set$HeartDisease <- factor(test_set$HeartDisease, levels = c(0, 1), labels = c("X0", "X1"))

# Subgroup check: XGBoost performance within each age band of the test set.
age_results <- results_by_age(test_set, xgb_tuned, "Age", "HeartDisease", age_groups)
## [1] "Group: <45"
## 
## X0 X1 
## 22  8 
## y_subset
## X0 X1 
## 22  8 
## pred_class
## X0 X1 
## 23  7 
## [1] "Group: 45-60"
## 
## X0 X1 
## 50 50 
## y_subset
## X0 X1 
## 50 50 
## pred_class
## X0 X1 
## 49 51 
## [1] "Group: >60"
## 
## X0 X1 
## 10 43 
## y_subset
## X0 X1 
## 10 43 
## pred_class
## X0 X1 
##  9 44
# Per-age-band metrics table.
print(age_results)
##           AgeGroup  Accuracy    Recall Precision        F1
## Accuracy       <45 0.9000000 0.9545455 0.9130435 0.9333333
## Accuracy1    45-60 0.8300000 0.8200000 0.8367347 0.8282828
## Accuracy2      >60 0.9433962 0.8000000 0.8888889 0.8421053
# Classify at a custom probability cutoff and report caret metrics.
#
# Args:
#   probabilities: probability table from predict(..., type = "prob"); the
#                  positive-class column is selected by name ("X1") when
#                  present, falling back to column 2 as before.
#   labels:        true classes (coercible to factor levels X0/X1).
#   threshold:     cutoff; probability >= threshold predicts "X1".
#
# Returns list(Accuracy, Recall, Precision, F1). Metrics use
# positive = "X1" so Recall/Precision refer to the disease class,
# consistent with evaluate_model(); the original defaulted to the first
# factor level "X0".
adjust_threshold <- function(probabilities, labels, threshold) {
  pos_probs <- if ("X1" %in% colnames(probabilities)) {
    probabilities[, "X1"]
  } else {
    probabilities[, 2]  # fallback: assume the positive class is column 2
  }
  predicted_class <- ifelse(pos_probs >= threshold, "X1", "X0")
  confusion <- caret::confusionMatrix(factor(predicted_class, levels = c("X0", "X1")),
                                      factor(labels, levels = c("X0", "X1")),
                                      positive = "X1")
  return(list(
    Accuracy = confusion$overall["Accuracy"],
    Recall = confusion$byClass["Recall"],
    Precision = confusion$byClass["Precision"],
    F1 = confusion$byClass["F1"]
  ))
}

# Sweep the decision threshold from 0.40 to 0.60 in 0.01 steps and collect
# the metrics returned by adjust_threshold() at each cutoff.
probabilities <- predict(xgb_tuned, X_test, type = "prob")
labels <- test_set$HeartDisease

thresholds <- seq(0.4, 0.6, by = 0.01)
results <- lapply(thresholds, \(thr) adjust_threshold(probabilities, labels, thr))

# Extract one named scalar metric per threshold, preserving input order.
pull_metric <- function(metric) sapply(results, function(x) x[[metric]])

threshold_adjustment <- data.frame(
  Threshold = thresholds,
  Accuracy = pull_metric("Accuracy"),
  Recall = pull_metric("Recall"),
  Precision = pull_metric("Precision"),
  F1 = pull_metric("F1")
)

# Overlay the four metric curves; each geom_line maps a constant string to
# colour so the legend is assembled from those labels.
# NOTE(review): tidyr::pivot_longer() plus a single
# geom_line(aes(colour = Metric)) would be the idiomatic long-format form.
ggplot(threshold_adjustment, aes(x = Threshold)) + 
  geom_line(aes(y = Accuracy, color = "Accuracy")) + 
  geom_line(aes(y = Recall, color = "Recall")) + 
  geom_line(aes(y = Precision, color = "Precision")) + 
  geom_line(aes(y = F1, color = "F1")) +
  labs(title = "Performance Metrics vs Threshold", y = "Metric Value", x = "Threshold") +
  theme_minimal()

library(glmnet)
## Loading required package: Matrix
## Loaded glmnet 4.1-10
# Fix the RNG so the CV fold assignment below is reproducible.
set.seed(42)
# NOTE(review): y_train/y_test/X_train/X_test are rebuilt identically to
# the XGBoost chunk above so this chunk can run standalone.
y_train <- factor(train_set$HeartDisease)
levels(y_train) <- make.names(levels(y_train))

y_test <- factor(test_set$HeartDisease)
levels(y_test) <- make.names(levels(y_test))

X_train <- model.matrix(HeartDisease ~ . -1, data = train_set)
X_test <- model.matrix(HeartDisease ~ . -1, data = test_set)

# L1-regularised logistic regression; lambda chosen by 10-fold CV on
# misclassification error (type.measure = "class").
lasso_model <- cv.glmnet(
  X_train, 
  y_train, 
  family = "binomial",
  alpha = 1, 
  standardize = TRUE,
  nlambda = 100,
  type.measure = "class",
  nfolds = 10
)

# Coefficients at the error-minimising lambda; non-zero rows are the
# lasso-selected features (row 1 is the intercept).
lasso_coefficients <- as.matrix(coef(lasso_model, s = "lambda.min"))

non_zero_indices <- which(lasso_coefficients != 0)
intercept <- lasso_coefficients[1]
selected_features <- rownames(lasso_coefficients)[non_zero_indices]
selected_coefficients <- lasso_coefficients[non_zero_indices]

# Assemble a human-readable logit equation from the non-zero lasso
# coefficients. Guarded with an explicit length check: the original
# `2:length(selected_features)` evaluates to c(2, 1) when only the
# intercept survives, which would index out of range.
formula_string <- paste0("logit(P(Y=1)) = ", round(intercept, 4))
if (length(selected_features) > 1) {
  for (i in seq(2, length(selected_features))) {
    formula_string <- paste0(formula_string, 
                             ifelse(selected_coefficients[i] > 0, " + ", " - "), 
                             abs(round(selected_coefficients[i], 4)), 
                             " * ", 
                             selected_features[i])
  }
}

cat(formula_string, "\n")
## logit(P(Y=1)) = -0.0014 + 0.7993 * FastingBS - 0.006 * MaxHR + 0.207 * Oldpeak + 0.9079 * Sex_M - 1.1793 * ChestPainType_ATA - 0.9266 * ChestPainType_NAP - 0.4739 * ChestPainType_TA + 0.8862 * ExerciseAngina_Y + 0.9971 * ST_Slope_Flat - 0.8801 * ST_Slope_Up
# Score the cross-validated lasso on the held-out test data, predicting
# at the misclassification-minimizing lambda.
res <- evaluate_model(
  model = lasso_model,
  glmnet_newx = X_test,
  glmnet_s = "lambda.min",
  test_data = test_set,
  outcome = "HeartDisease"
)

## Accuracy:     0.8634 
## ROC AUC:      0.9292 
## Brier Score:  0.0989
library(broom)
# Refit an ordinary (unpenalized) logistic regression so we can report
# standard errors and Wald confidence intervals, which glmnet does not
# provide. The refit is restricted to the variables the lasso actually
# selected at lambda.min (post-selection inference).
lasso_coefs <- as.matrix(coef(lasso_model, s = "lambda.min"))
selected_vars <- rownames(lasso_coefs)[which(lasso_coefs != 0)]
selected_vars <- selected_vars[selected_vars != "(Intercept)"]

# Bug fix: `selected_vars` was computed but never used -- the glm was
# fit on ALL columns of X_test, defeating the lasso selection. Subset
# the design matrix to the selected columns only.
# NOTE(review): this refits on the *test* set; consider refitting on
# train_set for honest out-of-sample inference -- confirm intent.
model_data <- data.frame(y = y_test, X_test[, selected_vars, drop = FALSE])

final_glm <- glm(y ~ ., data = model_data, family = "binomial")
lasso_results <- tidy(final_glm, conf.int = TRUE, conf.level = 0.95) %>%
  filter(term != "(Intercept)") %>%
  mutate(
    # Exponentiate the log-odds estimates and CI bounds into odds ratios
    # (the original comment incorrectly said "probability").
    OR = exp(estimate),
    conf.low = exp(conf.low),
    conf.high = exp(conf.high)
  )

print(lasso_results)
# NOTE(review): the rendered output below predates this fix and shows
# the full 15-term model; rerun the chunk to refresh it.
## # A tibble: 15 × 8
##    term           estimate std.error statistic p.value conf.low conf.high     OR
##    <chr>             <dbl>     <dbl>     <dbl>   <dbl>    <dbl>     <dbl>  <dbl>
##  1 Age             0.0804    0.0364     2.21   0.0272  1.01         1.17  1.08  
##  2 RestingBP      -0.00838   0.0148    -0.566  0.571   0.963        1.02  0.992 
##  3 Cholesterol     0.00812   0.00593    1.37   0.171   0.997        1.02  1.01  
##  4 FastingBS       1.61      0.725      2.22   0.0262  1.27        22.5   5.01  
##  5 MaxHR          -0.0178    0.0127    -1.39   0.163   0.957        1.01  0.982 
##  6 Oldpeak         0.0352    0.319      0.110  0.912   0.550        1.95  1.04  
##  7 Sex_M           2.00      0.735      2.72   0.00661 1.87        34.4   7.36  
##  8 ChestPainType… -2.55      0.865     -2.94   0.00323 0.0124       0.389 0.0783
##  9 ChestPainType… -1.83      0.757     -2.42   0.0157  0.0335       0.677 0.161 
## 10 ChestPainType…  0.180     1.09       0.165  0.869   0.143       10.8   1.20  
## 11 RestingECG_No…  1.29      0.745      1.73   0.0838  0.866       16.7   3.62  
## 12 RestingECG_ST   0.0519    0.845      0.0614 0.951   0.199        5.70  1.05  
## 13 ExerciseAngin…  0.364     0.654      0.556  0.578   0.386        5.15  1.44  
## 14 ST_Slope_Flat  -1.51      1.48      -1.02   0.307   0.00635      2.79  0.221 
## 15 ST_Slope_Up    -3.97      1.56      -2.54   0.0110  0.000451     0.271 0.0189
# Forest plot of odds ratios with 95% CIs; drop terms whose upper CI
# bound explodes so the x-axis stays readable.
lasso_results |>
  filter(conf.high < 1000) |>
  ggplot(aes(x = OR, y = reorder(term, OR))) +
    geom_point(size = 3, color = "blue") +
    # geom_errorbarh() is deprecated since ggplot2 4.0.0 (the rendered
    # warning confirmed it); use geom_errorbar() with orientation = "y".
    geom_errorbar(
      aes(xmin = conf.low, xmax = conf.high),
      width = 0.2,
      orientation = "y"
    ) +
    geom_vline(xintercept = 1, linetype = "dashed", color = "red") +
    labs(
      title = "Top 15 Risk Factors (Odds Ratios)",
      subtitle = "Right of Red Line = Increased Risk | Left = Protective",
      x = "Odds Ratio (95% CI)",
      y = "Risk Factor"
    ) +
    theme_minimal()

# Tidy the lasso coefficient matrix into a two-column data frame.
coef_df <- data.frame(
  Feature = rownames(lasso_coefs),
  Coefficient = lasso_coefs[, 1]
)

# Turn log-odds coefficients into integer scorecard points (x1000) and
# tag each row by how its points are applied when scoring a patient.
# NOTE(review): FastingBS is a 0/1 indicator (see summary()), so
# "multiply by value" and "add if true" coincide for it.
scorecard <- coef_df %>%
  filter(Coefficient != 0) %>%
  mutate(
    Points = round(Coefficient * 1000, 0),
    Type = case_when(
      Feature == "(Intercept)" ~ "Base Score",
      grepl("Age|RestingBP|Cholesterol|MaxHR|Oldpeak|FastingBS", Feature) ~
        "Continuous (Multiply by Value)",
      TRUE ~ "Binary (Add if True)"
    )
  ) %>%
  arrange(desc(abs(Coefficient)))

scorecard
##                             Feature  Coefficient Points
## ChestPainType_ATA ChestPainType_ATA -1.179346228  -1179
## ST_Slope_Flat         ST_Slope_Flat  0.997077640    997
## ChestPainType_NAP ChestPainType_NAP -0.926645318   -927
## Sex_M                         Sex_M  0.907866031    908
## ExerciseAngina_Y   ExerciseAngina_Y  0.886181302    886
## ST_Slope_Up             ST_Slope_Up -0.880054272   -880
## FastingBS                 FastingBS  0.799331553    799
## ChestPainType_TA   ChestPainType_TA -0.473938475   -474
## Oldpeak                     Oldpeak  0.206963088    207
## MaxHR                         MaxHR -0.005960471     -6
## (Intercept)             (Intercept) -0.001444751     -1
##                                             Type
## ChestPainType_ATA           Binary (Add if True)
## ST_Slope_Flat               Binary (Add if True)
## ChestPainType_NAP           Binary (Add if True)
## Sex_M                       Binary (Add if True)
## ExerciseAngina_Y            Binary (Add if True)
## ST_Slope_Up                 Binary (Add if True)
## FastingBS         Continuous (Multiply by Value)
## ChestPainType_TA            Binary (Add if True)
## Oldpeak           Continuous (Multiply by Value)
## MaxHR             Continuous (Multiply by Value)
## (Intercept)                           Base Score
# Linear-predictor (log-odds) and probability predictions for the test
# set at lambda.min.
test_scores <- predict(lasso_model, newx = X_test, s = "lambda.min", type = "link")
test_probs <- predict(lasso_model, newx = X_test, s = "lambda.min", type = "response")

# Per-patient table: scorecard-style integer risk score (log-odds x1000)
# alongside the model probability and the observed outcome.
patient_scores <- data.frame(
  Patient_ID = rownames(test_set),
  Actual_Status = test_set$HeartDisease,
  Risk_Score = round(1000 * as.numeric(test_scores)),
  Probability = round(as.numeric(test_probs), digits = 4)
)
head(patient_scores)
##   Patient_ID Actual_Status Risk_Score Probability
## 1          4            X1       1549      0.8247
## 2         11            X0      -2655      0.0657
## 3         17            X1        914      0.7138
## 4         23            X0      -3038      0.0457
## 5         26            X0      -1961      0.1233
## 6         30            X0      -2017      0.1174
# Scorecard risk curve: each patient's integer score vs predicted
# probability, with the theoretical sigmoid overlaid.
ggplot(patient_scores, aes(x = Risk_Score, y = Probability)) +
  geom_point(aes(color = factor(Actual_Status)), alpha = 0.6, size = 2) +
  # Bug fix: Risk_Score is log-odds * 1000, so the overlay must undo the
  # scaling (1 / (1 + exp(-x / 1000))). The original 1 / (1 + exp(-x))
  # saturated within +-5 score units and rendered as a step function the
  # points could never lie on.
  stat_function(fun = function(x) 1 / (1 + exp(-x / 1000)), color = "black", linewidth = 1) +

  geom_vline(xintercept = 0, linetype = "dashed", color = "gray50") +
  geom_hline(yintercept = 0.5, linetype = "dashed", color = "gray50") +

  scale_color_manual(values = c("#96ceb4", "#ff6f69"), labels = c("Healthy", "Disease")) +
  labs(
    title = "Scorecard Risk Curve",
    subtitle = "Higher Score = Higher Probability of Heart Disease",
    x = "Total Risk Score (Log-Odds)",
    y = "Predicted Probability",
    color = "Actual Outcome"
  ) +
  theme_minimal(base_size = 14) +
  # Same scale fix: scores span roughly -4000..+4000 (see table above),
  # so x = +-2 put both labels on top of each other at the center.
  annotate("text", x = -2000, y = 0.1, label = "Low Risk Zone", color = "darkgreen") +
  annotate("text", x = 2000, y = 0.9, label = "High Risk Zone", color = "darkred")